#include "dnnl.hpp"

// Pull the oneDNN API (engine, stream, memory, reorder, ...) into scope.
using namespace dnnl;

// Short aliases for the memory layout tags (nchw, oihw, any) and element
// types (f32) used when describing tensors below.
using tag = memory::format_tag;
using dt = memory::data_type;

extern "C"
void winconv(float *__restrict__ image, const int irows, const int icols,
             const int C, float *__restrict__ filter, const int K,
             const int batch, float *__restrict__ out) {
    dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    dnnl::stream engine_stream(engine);
    const memory::dim N = batch,
        IH = irows, IW = icols, OH = IH - 2, OW = IW - 2,
        IC = C, OC = K, KH = 3, KW = 3;
    memory::dims src_dims = {N, IC, IH, IW};
    memory::dims weights_dims = {OC, IC, KH, KW};
    memory::dims dst_dims = {N, OC, OH, OW};
    memory::dims strides_dims = {1, 1};
    memory::dims padding_dims_l = {0, 0};
    memory::dims padding_dims_r = {0, 0};
    // Create memory objects for tensor data (src, weights, dst)
    auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine, image);
    auto user_weights_mem = memory({weights_dims, dt::f32, tag::oihw}, engine, filter);
    auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine, out);
    // Create memory descriptors with format_tag::any for the primitive
    auto conv_src_md = memory::desc(src_dims, dt::f32, tag::any);
    auto conv_weights_md = memory::desc(weights_dims, dt::f32, tag::any);
    auto conv_dst_md = memory::desc(dst_dims, dt::f32, tag::any);
    // Create operation descriptor
    // === MODIFICATION NOTE === change `convolution_auto` -> `convolution_winograd` for avx512
#ifdef WINOGRAD
#define CNN_ALGO algorithm::convolution_winograd
#else
#define CNN_ALGO algorithm::convolution_direct
#endif
    auto conv_desc = convolution_forward::desc(prop_kind::forward_inference,
           CNN_ALGO, conv_src_md, conv_weights_md,
           conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
    // Create primitive descriptor.
    auto conv_pd = convolution_forward::primitive_desc(conv_desc, engine);
    // Create memory descriptors with format_tag::any for the primitive
    auto conv_src_mem = user_src_mem;
    auto conv_weights_mem = user_weights_mem;
    auto conv_dst_mem = user_dst_mem;
    // Reorder the data in case memory layouts are different
    if (conv_pd.src_desc() != user_src_mem.get_desc()) {
        conv_src_mem = memory(conv_pd.src_desc(), engine);
        reorder(user_src_mem, conv_src_mem)
                .execute(engine_stream, user_src_mem, conv_src_mem);
    }
    if (conv_pd.weights_desc() != user_weights_mem.get_desc()) {
        conv_weights_mem = memory(conv_pd.weights_desc(), engine);
        reorder(user_weights_mem, conv_weights_mem)
                .execute(engine_stream, user_weights_mem, conv_weights_mem);
    }
    if (conv_pd.dst_desc() != user_dst_mem.get_desc()) {
        conv_dst_mem = memory(conv_pd.dst_desc(), engine);
    }
    // Create the primitive.
    auto conv_prim = convolution_forward(conv_pd);
    // Primitive arguments.
    std::unordered_map<int, memory> conv_args;
    conv_args.insert({DNNL_ARG_SRC, conv_src_mem});
    conv_args.insert({DNNL_ARG_WEIGHTS, conv_weights_mem});
    conv_args.insert({DNNL_ARG_DST, conv_dst_mem});
    // Primitive execution: convolution
    conv_prim.execute(engine_stream, conv_args);
    if (conv_pd.dst_desc() != user_dst_mem.get_desc()) {
        reorder(conv_dst_mem, user_dst_mem)
                .execute(engine_stream, conv_dst_mem, user_dst_mem);
    } else
        user_dst_mem = conv_dst_mem;
    // Wait for the computation to finalize.
    engine_stream.wait();
}
